{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 16、PyTorch中进行卷积残差模块算子融合"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import Conv2d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "weight\n",
      "bias\n",
      "------\n",
      "Parameter containing:\n",
      "tensor([[[[-0.1311,  0.2241,  0.0320],\n",
      "          [ 0.0384,  0.1362,  0.1443],\n",
      "          [-0.1571,  0.0456,  0.2345]],\n",
      "\n",
      "         [[ 0.0272, -0.1808,  0.2062],\n",
      "          [-0.2220, -0.1941, -0.1617],\n",
      "          [ 0.1410,  0.1927, -0.2115]]],\n",
      "\n",
      "\n",
      "        [[[ 0.0928,  0.1705, -0.1105],\n",
      "          [ 0.0279, -0.0829,  0.1401],\n",
      "          [ 0.0755,  0.1934, -0.1110]],\n",
      "\n",
      "         [[-0.1686, -0.1503,  0.1545],\n",
      "          [ 0.0384,  0.2028,  0.0521],\n",
      "          [-0.0017,  0.1650,  0.1246]]]], requires_grad=True)\n",
      "torch.Size([2, 2, 3, 3])\n"
     ]
    }
   ],
   "source": [
    "# in_channels: int, 输入特征维度\n",
    "# out_channels: int, 输出特征维度\n",
    "# kernel_size: _size_2_t, 卷积核大小\n",
    "# stride: _size_2_t = 1, 步长，滑动跳几步\n",
    "# padding: Union[str, _size_2_t] = 0,分为 valid/same\n",
    "# dilation: _size_2_t = 1, 空洞卷积，原本是实心的，变相的变大感受范围\n",
    "# groups: int = 1, 分为多少组，必须被 in/out channels整除。 深度可分离卷积用\n",
    "# bias: bool = True,\n",
    "# padding_mode: str = \"zeros\",  # TODO: refine this type\n",
    "\n",
    "conv_laye1 = Conv2d(2, 2, 3, padding=\"same\")\n",
    "for i in conv_laye1.named_parameters():\n",
    "    print(i[0])\n",
    "print('------')\n",
    "print(conv_laye1.weight)\n",
    "print(conv_laye1.weight.shape)\n",
    "# torch.Size([2, 2, 3, 3]) 分别是 out_channels, in_channels, kernel_width, kernel_height"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4, 1, 3, 3]), torch.Size([4]))"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# group 用法\n",
    "conv_laye1 =  Conv2d(2, 4, 3, padding=\"same\", groups=2)\n",
    "# 其实就是不同的卷积输出然后拼接起来\n",
    "conv_laye1.weight.shape, conv_laye1.bias.shape # (torch.Size([4, 1, 3, 3]), torch.Size([4]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([2, 1, 3, 3]), torch.Size([2]))"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# group 用法\n",
    "conv_laye1 =  Conv2d(1, 2, 3, padding=\"same\")\n",
    "# 其实就是不同的卷积输出然后拼接起来\n",
    "conv_laye1.weight.shape, conv_laye1.bias.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 深度可分离卷积\n",
    "\n",
    "`res_block =  3*3 conv + 1*1 conv + input`\n",
    "\n",
    "![image-20250619202922931](http://assets.hypervoid.top/img/2025/06/19/image-20250619202922931-79dd.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "in_channels = 2\n",
    "out_channels = 2\n",
    "kernel_size = 3 \n",
    "batch_size = 1\n",
    "w, h = 28, 28 # 图片宽高"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 原生写法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([1, 2, 28, 28]),\n",
       " tensor([[[[0.4720, 1.1708, 1.1708,  ..., 1.1708, 1.1708, 1.3344],\n",
       "           [0.8485, 1.5441, 1.5441,  ..., 1.5441, 1.5441, 1.5185],\n",
       "           [0.8485, 1.5441, 1.5441,  ..., 1.5441, 1.5441, 1.5185],\n",
       "           ...,\n",
       "           [0.8485, 1.5441, 1.5441,  ..., 1.5441, 1.5441, 1.5185],\n",
       "           [0.8485, 1.5441, 1.5441,  ..., 1.5441, 1.5441, 1.5185],\n",
       "           [0.7329, 1.0327, 1.0327,  ..., 1.0327, 1.0327, 0.7811]],\n",
       " \n",
       "          [[1.1891, 1.3983, 1.3983,  ..., 1.3983, 1.3983, 1.0062],\n",
       "           [0.9465, 1.5522, 1.5522,  ..., 1.5522, 1.5522, 1.2437],\n",
       "           [0.9465, 1.5522, 1.5522,  ..., 1.5522, 1.5522, 1.2437],\n",
       "           ...,\n",
       "           [0.9465, 1.5522, 1.5522,  ..., 1.5522, 1.5522, 1.2437],\n",
       "           [0.9465, 1.5522, 1.5522,  ..., 1.5522, 1.5522, 1.2437],\n",
       "           [0.4166, 0.8005, 0.8005,  ..., 0.8005, 0.8005, 0.8232]]]],\n",
       "        grad_fn=<AddBackward0>))"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.ones((batch_size, in_channels, w, h))\n",
    "conv3 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=\"same\")\n",
    "conv1 = nn.Conv2d(in_channels, out_channels, 1)\n",
    "result1 = conv3(x) + conv1(x) + x\n",
    "result1.shape, result1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 算子融合加速写法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[0.4720, 1.1708, 1.1708,  ..., 1.1708, 1.1708, 1.3344],\n",
       "          [0.8485, 1.5441, 1.5441,  ..., 1.5441, 1.5441, 1.5185],\n",
       "          [0.8485, 1.5441, 1.5441,  ..., 1.5441, 1.5441, 1.5185],\n",
       "          ...,\n",
       "          [0.8485, 1.5441, 1.5441,  ..., 1.5441, 1.5441, 1.5185],\n",
       "          [0.8485, 1.5441, 1.5441,  ..., 1.5441, 1.5441, 1.5185],\n",
       "          [0.7329, 1.0327, 1.0327,  ..., 1.0327, 1.0327, 0.7811]],\n",
       "\n",
       "         [[1.1891, 1.3983, 1.3983,  ..., 1.3983, 1.3983, 1.0062],\n",
       "          [0.9465, 1.5522, 1.5522,  ..., 1.5522, 1.5522, 1.2437],\n",
       "          [0.9465, 1.5522, 1.5522,  ..., 1.5522, 1.5522, 1.2437],\n",
       "          ...,\n",
       "          [0.9465, 1.5522, 1.5522,  ..., 1.5522, 1.5522, 1.2437],\n",
       "          [0.9465, 1.5522, 1.5522,  ..., 1.5522, 1.5522, 1.2437],\n",
       "          [0.4166, 0.8005, 0.8005,  ..., 0.8005, 0.8005, 0.8232]]]],\n",
       "       grad_fn=<ConvolutionBackward0>)"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 把 conv1 和 x 自身写作 3*3 卷积形式\n",
    "# 最后简化为一个 conv 层\n",
    "\n",
    "# 8个元素，对应4个维度（每个维度有 left/top/front 和 right/bottom/back 两个填充量）\n",
    "# 维度分别是 width、height、in_channels、out_channels\n",
    "conv1_w = F.pad(conv1.weight, [1, 1, 1, 1, 0, 0, 0, 0])\n",
    "\n",
    "conv_x_w = torch.zeros_like(conv1_w)\n",
    "conv_x_w[0, 0, 1, 1] = 1  # 设置通道为x自身\n",
    "conv_x_w[1, 1, 1, 1] = 1  # 设置通道为x自身\n",
    "\n",
    "\n",
    "final_conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=\"same\")\n",
    "final_conv.weight = nn.Parameter(conv3.weight + conv1_w + conv_x_w)\n",
    "final_conv.bias = nn.Parameter(conv3.bias + conv1.bias + 0)\n",
    "\n",
    "result2 = final_conv(x)\n",
    "result2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.all(torch.isclose(result1, result2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 比较运行效率\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Use a larger batch than the demo above so each forward pass is long enough to time.\n",
    "x = torch.ones((batch_size * 100, in_channels, w, h))\n",
    "\n",
    "# Build every layer BEFORE starting the clock so that only forward passes are measured.\n",
    "conv3 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=\"same\")\n",
    "conv1 = nn.Conv2d(in_channels, out_channels, 1)\n",
    "\n",
    "# Fuse the 3x3 conv, the 1x1 conv and the identity shortcut into one conv layer.\n",
    "pad = kernel_size // 2  # offset of the kernel centre (assumes an odd kernel_size)\n",
    "conv1_w = F.pad(conv1.weight, [pad, pad, pad, pad, 0, 0, 0, 0])\n",
    "\n",
    "conv_x_w = torch.zeros_like(conv1_w)\n",
    "for channel in range(in_channels):\n",
    "    conv_x_w[channel, channel, pad, pad] = 1  # output channel c copies input channel c\n",
    "\n",
    "final_conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=\"same\")\n",
    "final_conv.weight = nn.Parameter(conv3.weight + conv1_w + conv_x_w)\n",
    "final_conv.bias = nn.Parameter(conv3.bias + conv1.bias)\n",
    "\n",
    "n_runs = 100  # both variants are timed over the same number of forward passes\n",
    "\n",
    "# Inference-only comparison: autograd bookkeeping is disabled for both variants.\n",
    "with torch.no_grad():\n",
    "    # Warm-up pass for each variant so one-off setup cost is not attributed to either.\n",
    "    conv3(x) + conv1(x) + x\n",
    "    final_conv(x)\n",
    "\n",
    "    t1 = time.perf_counter()\n",
    "    for _run in range(n_runs):\n",
    "        result1 = conv3(x) + conv1(x) + x\n",
    "    t2 = time.perf_counter()\n",
    "\n",
    "    t3 = time.perf_counter()\n",
    "    for _run in range(n_runs):\n",
    "        result2 = final_conv(x)\n",
    "    t4 = time.perf_counter()\n",
    "\n",
    "# Both variants must still agree, otherwise the timing comparison is meaningless.\n",
    "assert torch.allclose(result1, result2, atol=1e-5)\n",
    "\n",
    "print(t2-t1)  # total seconds, unfused block\n",
    "print(t4-t3)  # total seconds, fused conv\n",
    "\n",
    "print((t2-t1)/ (t4-t3))  # speed-up factor of the fused conv"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
